fix: scope get_full_cu_seqlens cache key by device and inference mode#2728

Open
DmCarpe93 wants to merge 8 commits into NVIDIA:main from DmCarpe93:fix/get_full_cu_seqlens_cache_key_error

Conversation

@DmCarpe93

@DmCarpe93 DmCarpe93 commented Mar 3, 2026

Description

Fixed an issue where the cu_seqlens tensor was incorrectly retrieved from the cache.

  • Currently, only (batch_size, max_seqlen) is used as the cache key when retrieving cu_seqlens.
  • This could cause errors, especially in Knowledge Distillation training, where the teacher and student models may run on the same node.
    • When the teacher model runs first, a cu_seqlens tensor is created and cached.
    • When the student model then trains on the same node, the cached cu_seqlens tensor is reused whenever the same (batch_size, max_seqlen) appears.
    • Since the cached cu_seqlens tensor from the teacher model may have been created under a different inference mode or on a different device, reusing it can raise an error.
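The collision can be modelled with a minimal pure-Python sketch (hypothetical names; `make_cu_seqlens` is a stand-in for the real torch.arange-based creation in transformer_engine/pytorch/attention/dot_product_attention/utils.py, and the dict merely records which context produced the tensor):

```python
# Simplified model of the cu_seqlens cache, keyed the OLD way.
_cache = {}

def make_cu_seqlens(batch_size, max_seqlen, device, inference):
    # The real code builds a torch tensor; a plain list plus its creation
    # context is enough to show the bug.
    values = list(range(0, (batch_size + 1) * max_seqlen, max_seqlen))
    return {"values": values, "device": device, "inference": inference}

def get_full_cu_seqlens_old(batch_size, max_seqlen, device, inference):
    key = (batch_size, max_seqlen)  # old key: creation context is ignored
    if key not in _cache:
        _cache[key] = make_cu_seqlens(batch_size, max_seqlen, device, inference)
    return _cache[key]

# Teacher runs first under inference mode on cuda:0 ...
teacher = get_full_cu_seqlens_old(2, 128, "cuda:0", True)
# ... then the student, training on cuda:1, gets the teacher's tensor back.
student = get_full_cu_seqlens_old(2, 128, "cuda:1", False)
assert student is teacher  # collision: wrong device and wrong mode
```

With real torch tensors, the student pass would receive an inference-mode tensor on the wrong device, which is exactly the KD failure described above.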

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • The cache key for retrieving cu_seqlens was extended from (batch_size, max_seqlen) to also include the device and the inference mode.
  • Added test cases for the cu_seqlens cache.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps
Contributor

greptile-apps bot commented Mar 3, 2026

Greptile Summary

This PR fixes a cache-collision bug in get_full_cu_seqlens where two logically distinct tensors — one created under torch.inference_mode() and one in normal training, or one on cuda:0 and another on cuda:1 — could resolve to the same cache key (batch_size, max_seqlen) and incorrectly share a cached cu_seqlens tensor. The fix expands the key to (batch_size, max_seqlen, device, is_inference_mode_enabled()), which is the minimal correct fix for the described Knowledge Distillation scenario.

  • Core fix (utils.py): Two lines — capture torch.is_inference_mode_enabled() and build a 4-tuple cache key. No logic beyond that is touched.
  • Device isolation: tensor.device is always a concrete torch.device (e.g. cuda:0) so equality and hashing are well-defined; no ambiguity from un-indexed "cuda" strings.
  • Inference vs. no-grad: The fix correctly targets torch.inference_mode() (tensors are flagged is_inference=True and cannot leave the context) rather than the weaker torch.no_grad(), which was the actual source of the runtime error in KD training.
  • Cache unboundedness: The global _cu_seqlens_cache dict is still never evicted. The new dimensions add at most 2 × num_devices entries per (batch_size, max_seqlen) pair, which is negligible in practice and was already a pre-existing concern.
  • Tests: Two new well-structured tests cover both isolation axes; the autouse fixture correctly clears the cache before and after each test to prevent cross-test contamination.

Confidence Score: 5/5

Safe to merge — minimal, targeted fix with no regressions and direct test coverage.

No P0/P1 findings. The change is a two-line key extension that directly addresses the described bug. The torch.device and boolean values used as key components are correctly hashable and produce stable equality. Tests cover both scenarios mentioned in the PR description.

No files require special attention.

Important Files Changed

Filename Overview
transformer_engine/pytorch/attention/dot_product_attention/utils.py Cache key for get_full_cu_seqlens extended from (batch_size, max_seqlen) to (batch_size, max_seqlen, device, is_inference) to prevent cross-device and cross-inference-mode cache collisions.
tests/pytorch/attention/test_cu_seqlens_cache.py New test file verifying that the cu_seqlens cache correctly isolates entries by device (multi-GPU) and by inference vs. training mode.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["get_full_cu_seqlens called"] --> B{ONNX export mode?}
    B -- Yes --> C["Skip cache, return directly"]
    B -- No --> D["Read torch.is_inference_mode_enabled"]
    D --> E["Build 4-tuple cache lookup: batch+seqlen+device+inference_flag"]
    E --> F{Found in cache?}
    F -- Yes --> G["Return cached cu_seqlens tensor"]
    F -- No --> H["Create new tensor via torch.arange"]
    H --> I["Store in cache and return"]

Reviews (7): Last reviewed commit: "Merge branch 'main' into fix/get_full_cu..."

Contributor

@greptile-apps greptile-apps bot left a comment


2 files reviewed, no comments


@ptrendx ptrendx requested a review from cyanguwa March 3, 2026 18:54
@DmCarpe93
Author

@cyanguwa When you have a moment, could you please take a look at this PR? Thanks:)

@DmCarpe93
Author

@cyanguwa This PR is pretty straightforward. Would you mind taking a quick look? Thank you:)

@DmCarpe93
Author

DmCarpe93 commented Apr 1, 2026

@cyanguwa Hi :) Could you look into this PR? Thank you.

